# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from pathlib import Path
from time import sleep

import wget
from lightning.pytorch.plugins.environments import LightningEnvironment
from lightning.pytorch.strategies import DDPStrategy, StrategyRegistry

from nemo.utils import logging


def maybe_download_from_cloud(url, filename, subfolder=None, cache_dir=None, refresh_cache=False) -> str:
    """
    Helper function to download pre-trained weights from the cloud, caching them locally.

    Args:
        url: (str) URL of storage. It is concatenated verbatim with ``filename`` (``url + filename``), so it
            should normally end with a ``/``.
        filename: (str) what to download. The request will be issued to url + filename
        subfolder: (str) subfolder within cache_dir. The file will be stored in cache_dir/subfolder. Subfolder can
            be empty
        cache_dir: (str or Path) a cache directory where to download. If not present, this function will attempt
            to create it. If None (default), then it will be $HOME/.cache/torch/NeMo
        refresh_cache: (bool) if True and cached file is present, it will delete it and re-fetch

    Returns:
        If successful - local path (as ``str``) to the cached or downloaded file
        else - empty string (the download call returned but no file exists at the destination)

    Raises:
        ValueError: if every download attempt raised an exception.
    """
    if cache_dir is None:
        cache_location = Path.home() / ".cache/torch/NeMo"
    else:
        # Normalize to Path so that both str and Path inputs are accepted.
        cache_location = Path(cache_dir)
    destination = cache_location / subfolder if subfolder is not None else cache_location

    # exist_ok=True already covers the "directory exists" case; no separate check needed.
    os.makedirs(destination, exist_ok=True)

    destination_file = destination / filename

    if destination_file.exists():
        logging.info(f"Found existing object {destination_file}.")
        if refresh_cache:
            logging.info("Asked to refresh the cache.")
            logging.info(f"Deleting file: {destination_file}")
            os.remove(destination_file)
        else:
            logging.info(f"Re-using file from: {destination_file}")
            return str(destination_file)

    # download file
    wget_uri = url + filename
    logging.info(f"Downloading from: {wget_uri} to {str(destination_file)}")
    # NGC links do not work every time, so retry a few times with a short pause between attempts.
    max_attempts = 3
    last_error = None
    for attempt in range(1, max_attempts + 1):
        try:
            wget.download(wget_uri, str(destination_file))
        except Exception as err:  # not a bare except: Ctrl-C / SystemExit must still propagate
            last_error = err
            logging.info(f"Download from cloud failed. Attempt {attempt} of {max_attempts}")
            sleep(0.05)
            continue
        # Always return a str, matching the cache-hit path and the `-> str` annotation.
        return str(destination_file) if destination_file.exists() else ""
    raise ValueError("Not able to download url right now, please try again.") from last_error


class SageMakerDDPStrategy(DDPStrategy):
    """
    DDPStrategy whose cluster environment reports world size and global rank taken from the
    ``WORLD_SIZE`` and ``RANK`` environment variables.
    """

    @property
    def cluster_environment(self):
        """Build a new LightningEnvironment whose rank/size queries read ``os.environ`` at call time."""

        def _int_env_reader(variable):
            # Defer the lookup: each call re-reads and int-converts the current value.
            return lambda: int(os.environ[variable])

        lightning_env = LightningEnvironment()
        lightning_env.world_size = _int_env_reader("WORLD_SIZE")
        lightning_env.global_rank = _int_env_reader("RANK")
        return lightning_env

    @cluster_environment.setter
    def cluster_environment(self, env):
        """Discard assignments so Lightning cannot replace the environment required for SageMaker."""
        # Intentionally a no-op: the getter always rebuilds the environment from os.environ.


def initialize_sagemaker() -> None:
    """
    Helper function to initiate sagemaker with NeMo.

    This function:
      * registers an ``smddp`` strategy (``SageMakerDDPStrategy`` with the ``smddp`` process-group backend)
        in Lightning's ``StrategyRegistry``;
      * patches ``torchmetrics.Metric.__hash__`` (see ``_patch_torch_metrics``);
      * installs the system libraries NeMo requires for the ASR toolkit (``libsndfile1``, ``ffmpeg``);
      * when both ``RANK`` and ``WORLD_SIZE`` are set, initializes the SageMaker data-parallel process group,
        lets only local rank 0 run the installation, and makes every rank wait at a barrier afterwards.
    """

    StrategyRegistry.register(
        name='smddp',
        strategy=SageMakerDDPStrategy,
        process_group_backend="smddp",
        find_unused_parameters=False,
    )

    def _install_system_libraries() -> None:
        """Install libsndfile1 and ffmpeg through apt-get, logging a warning if the shell command fails."""
        # NOTE(review): `chmod 777 /tmp` is presumably required for apt-get inside the SageMaker container — confirm.
        exit_status = os.system('chmod 777 /tmp && apt-get update && apt-get install -y libsndfile1 ffmpeg')
        if exit_status != 0:
            # Best-effort: report the failure instead of silently ignoring it, but do not abort initialization.
            logging.warning(f"Installing system libraries failed (os.system exit status {exit_status}).")

    def _patch_torch_metrics() -> None:
        """
        Patches torchmetrics to not rely on internal state.
        This is because sagemaker DDP overrides the `__init__` function of the modules to do automatic-partitioning.
        """
        from torchmetrics import Metric

        def __new_hash__(self):
            # Hash on class name + object identity only, never on metric state.
            hash_vals = [self.__class__.__name__, id(self)]
            return hash(tuple(hash_vals))

        Metric.__hash__ = __new_hash__

    _patch_torch_metrics()

    if os.environ.get("RANK") and os.environ.get("WORLD_SIZE"):
        import smdistributed.dataparallel.torch.distributed as dist

        # has to be imported, as it overrides torch modules and such when DDP is enabled.
        import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401

        dist.init_process_group()

        # Only the main process (local rank 0) installs; running apt-get concurrently from several
        # ranks on the same machine would contend for the package-manager lock.
        if dist.get_local_rank() == 0:
            _install_system_libraries()
        dist.barrier()  # wait for main process
        return
    _install_system_libraries()
